import os
import click
import torch
import torch.distributed as dist
import yaml
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from copy import deepcopy
from collections import OrderedDict
import torchvision.utils

import utils.graph_lib
import utils.samplers
import wandb
from models.model_utils import get_model, get_preconditioned_model
from utils.losses import get_loss
from utils.misc import dotdict
from utils.optimizers import WarmUpScheduler

# Allow TF32 tensor-core math for both cuBLAS matmuls and cuDNN convolutions.
# This trades a little float32 precision for much faster training on
# Ampere-class GPUs such as A100s.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

def init_wandb(opts):
    """Open a Weights & Biases run describing this training job.

    The run goes to the ``discrete-diffusion`` project. It is named
    ``<model>-<dataset>``, tagged with ``training`` and the dataset name, and
    ``opts`` is stored as the run config.
    """
    run_settings = dict(
        project='discrete-diffusion',  # wandb project that collects the runs
        name=f'{opts.model}-{opts.dataset}',
        tags=['training', opts.dataset],
        config=opts,  # hyperparameters / run metadata
    )
    wandb.init(**run_settings)

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Move every EMA parameter one step towards its counterpart in ``model``.

    Each parameter of ``model`` is matched by name to ``ema_model``. The EMA
    tensor is updated in place as ``ema = decay * ema + (1 - decay) * param``.
    Only parameters are averaged; buffers are left untouched.
    """
    shadow = dict(ema_model.named_parameters())
    for pname, live in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        shadow[pname].mul_(decay).add_(live.data, alpha=1 - decay)
        
def requires_grad(model, flag=True):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: module whose ``parameters()`` are updated in place.
        flag: value written to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad_(flag)


def get_dataset(name, root='.'):
    """Return the training dataset called ``name``.

    Only ``'mnist'`` (case-insensitive) is supported. Images are converted with
    ``T.PILToTensor()``, so pixels stay integer-valued rather than being scaled
    to floats. The data is downloaded into ``root`` if it is not already present.

    Args:
        name: dataset identifier. Previously ignored; now validated, so a
            typo no longer silently yields MNIST.
        root: directory where the dataset is stored / downloaded.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if str(name).lower() != 'mnist':
        raise ValueError(f"Unsupported dataset {name!r}; only 'mnist' is available.")
    return MNIST(root, transform=T.PILToTensor(), download=True)

@click.command()
@click.option('--dataset',type=click.Choice(['mnist']), default='mnist')
@click.option('--model',type=click.Choice(['uvit']), default='uvit')
# This flag used to be ignored: AdamW was always built. The default is now
# 'adamw', so invocations without the flag keep their previous behaviour,
# while an explicit '--optimizer adam' really selects Adam.
@click.option('--optimizer',type=click.Choice(['adam','adamw']), default='adamw')
@click.option('--lr', type=float, default=1e-4)
@click.option('--ema_beta', type=float, default=.9999)
@click.option('--batch_size', type=int, default=32, help='Global batch size, split evenly across ranks')
@click.option('--log_rate',type=int,default=5000)
@click.option('--num_iters',type=int,default=600000)
@click.option('--warmup_iters',type=int,default=2500)
@click.option('--num_workers',type=int,default=2)
@click.option('--seed',type=int,default=42)
@click.option('--dir',type=str, required=True, help='Output directory for checkpoints and samples')
@click.option('--net_config_path',type=str, default='configs/uvit.yaml')
@click.option('--load_checkpoint',type=str, help='Path of a snapshot .pt file (as written by save_ckpt) to resume from')
@click.option('--enable_wandb', is_flag=True, default=False)
def training(**opts):
    """Train a discrete-diffusion model with DistributedDataParallel (CLI entry).

    ``dist.init_process_group('nccl')`` is called without rendezvous
    arguments, so this must be started by a launcher that sets up the
    ``torch.distributed`` environment (e.g. ``torchrun``), one process per GPU.

    Every ``--log_rate`` iterations, and at the final one, rank 0 writes
    ``<dir>/itr_<iter>/snapshot.pt``. It also writes per-sample PNGs and
    3-column grids for the online and EMA models. ``<dir>/final_checkpoint.pt``
    is written when training ends.
    """
    opts = dotdict(opts)
    batch_size = opts.batch_size

    dist.init_process_group('nccl')
    world_size = dist.get_world_size()
    # Raise instead of assert: asserts are stripped under `python -O`.
    if batch_size % world_size != 0:
        raise ValueError("Batch size must be divisible by world size.")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = opts.seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    with open(opts.net_config_path) as config_file:
        net_opts = dotdict(yaml.safe_load(config_file))
    if rank == 0:
        print(opts)
        print(net_opts)

    wandb_enabled = opts.enable_wandb and rank == 0  # We only want to log once
    if wandb_enabled:
        init_wandb(opts)
        wandb.config.update(net_opts)

    dataset = get_dataset(opts.dataset)

    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=opts.seed
    )
    dataloader = DataLoader(dataset, batch_size=batch_size // world_size, shuffle=False,
                            sampler=sampler, drop_last=True, num_workers=opts.num_workers, pin_memory=True)

    vocab_size = 256
    context_len = 784  # 28 * 28 pixels, one token per pixel
    graph = utils.graph_lib.Absorbing(vocab_size)

    model = get_model(opts.model, vocab_size + 1, context_len, net_opts)
    ema = deepcopy(model)
    model = get_preconditioned_model(model, graph).to(device)
    ema = get_preconditioned_model(ema, graph).to(device)
    # The EMA copy is only written by update_ema and only read for sampling.
    # Freeze it and keep it in eval mode; it was never switched to eval before,
    # so it was sampled in whatever mode it was constructed in.
    requires_grad(ema, False)
    ema.eval()
    opt = _build_optimizer(opts.optimizer, model.parameters(), opts.lr)
    scheduler = WarmUpScheduler(opt, opts.warmup_iters)
    # GradScaler takes a device *type* string, not a CUDA ordinal.
    scaler = torch.amp.GradScaler('cuda')

    start_iter = 0
    dist.barrier()
    if opts.load_checkpoint is not None:
        start_iter = load_checkpoint(opts, rank, device, model, ema, opt, scheduler)
    dist.barrier()

    model.train()
    model = DDP(model, device_ids=[device])

    if rank == 0:
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model parameters: {n_params / 1e6:.2f} M")
        os.makedirs(opts.dir, exist_ok=True)

    num_iters = opts.num_iters
    log_rate = opts.log_rate
    loss_fn = get_loss(graph)

    # drop_last=True, so this is exactly the number of optimizer steps per epoch.
    iters_per_epoch = len(dataloader)
    if iters_per_epoch == 0:
        raise ValueError("Dataset is smaller than one per-rank batch; nothing to train on.")
    epochs = (num_iters + iters_per_epoch - 1) // iters_per_epoch

    training_iter = start_iter
    # Continue the epoch counter when resuming so the data order is not replayed.
    epoch = training_iter // iters_per_epoch
    while training_iter < num_iters:
        # Without set_epoch, DistributedSampler yields the same permutation every epoch.
        sampler.set_epoch(epoch)
        pbar = tqdm(dataloader, total=iters_per_epoch, leave=False) if rank == 0 else dataloader
        for data, labels in pbar:
            data = data.to(device=device, dtype=torch.long)
            labels = labels.to(device=device, dtype=torch.long).unsqueeze(-1)

            opt.zero_grad()
            loss = loss_fn(model=model, data=data, cond=labels)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            for param in model.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
            scaler.step(opt)
            scaler.update()
            update_ema(ema.net, model.module.net, opts.ema_beta)
            scheduler.step()

            training_iter += 1

            # Average the loss over ranks for logging. Reduce a detached copy so
            # the autograd tensor itself is not modified in place.
            loss_avg = loss.detach().clone()
            dist.all_reduce(loss_avg, op=dist.ReduceOp.SUM)
            loss_value = loss_avg.item() / world_size

            if rank == 0:
                pbar.set_description(f'Epoch {epoch}/{epochs} --- Iter {training_iter} --- Loss : {loss_value :6.2f}')
            if wandb_enabled:
                wandb.log({
                'loss': loss_value})

            # Checkpoint and sample
            if training_iter % log_rate == 0 or training_iter == num_iters:
                path = os.path.join(opts.dir, f'itr_{training_iter}/')
                model.eval()
                if rank == 0:
                    os.makedirs(path, exist_ok=True)
                    save_ckpt(model, ema, opt, scheduler, os.path.join(path, 'snapshot.pt'))

                    # Sample on rank 0 only, from the unwrapped module. Previously
                    # every rank sampled and wrote the same file names, racing
                    # each other.
                    n_samples = 9
                    # Fixed labels 0..n-1 (mod the 10 MNIST classes) instead of
                    # the last batch's labels. A per-rank batch can hold fewer
                    # than n_samples labels.
                    sample_labels = (torch.arange(n_samples, device=device) % 10).unsqueeze(-1)
                    samples = utils.samplers.get_pc_sampler(model.module, (n_samples, context_len), sample_labels, 100, device=device, graph=graph)
                    samples_ema = utils.samplers.get_pc_sampler(ema, (n_samples, context_len), sample_labels, 100, device=device, graph=graph)

                    grid = _save_samples(samples, os.path.join(path, 'samples/'))
                    grid_ema = _save_samples(samples_ema, os.path.join(path, 'samples_ema/'))
                    torchvision.utils.save_image(grid, os.path.join(path, 'grid.png'))
                    torchvision.utils.save_image(grid_ema, os.path.join(path, 'grid_ema.png'))

                    if wandb_enabled:
                        wandb.log({
                            "samples": wandb.Image(grid),
                            "samples_ema": wandb.Image(grid_ema),
                            "training_iter": training_iter
                        })
                # Other ranks wait here while rank 0 checkpoints and samples.
                dist.barrier()
                model.train()

            if training_iter >= num_iters:
                break
        if rank == 0:
            pbar.close()
        epoch += 1

    if rank == 0:
        save_ckpt(model, ema, opt, scheduler, os.path.join(opts.dir, 'final_checkpoint.pt'))

    dist.barrier()
    if wandb_enabled:
        wandb.finish()
    dist.destroy_process_group()


def _build_optimizer(name, params, lr):
    """Return the optimizer selected by ``--optimizer`` (``'adam'`` or ``'adamw'``)."""
    optimizers = {'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
    if name not in optimizers:
        raise ValueError(f"Unsupported optimizer: {name!r}")
    return optimizers[name](params, lr=lr)


def _save_samples(samples, out_dir):
    """Save each sample as ``out_dir/sample_{i}.png`` and return a 3-column grid.

    ``samples`` is reshaped to (N, 1, 28, 28) and divided by 255. This assumes
    it holds integer pixel tokens in [0, 255]; TODO confirm against the sampler.
    """
    os.makedirs(out_dir, exist_ok=True)
    images = samples.reshape(-1, 1, 28, 28).cpu().float() / 255.0  # Normalize to [0,1]
    for i, img in enumerate(images):
        torchvision.utils.save_image(img, os.path.join(out_dir, f'sample_{i}.png'))
    return torchvision.utils.make_grid(images, nrow=3)

def load_checkpoint(opts, rank, device, model, ema, opt, scheduler):
    """Restore model, EMA, optimizer and scheduler state from a snapshot file.

    Args:
        opts: options object; ``opts.load_checkpoint`` is the path of a
            snapshot with keys ``model``, ``ema``, ``optimizer`` and
            ``scheduler`` (the layout written by ``save_ckpt``).
        rank: distributed rank, used only in log messages.
        device: unused; kept for call-site compatibility. Tensors are loaded on
            CPU and ``load_state_dict`` copies them onto the parameters'
            devices.
        model: unwrapped model whose ``.net`` receives ``snapshot['model']``
            (non-strict).
        ema: EMA model whose ``.net`` receives ``snapshot['ema']`` (strict).
        opt: optimizer to restore.
        scheduler: LR scheduler to restore.

    Returns:
        ``scheduler.last_epoch`` after loading, used as the resume iteration.
    """
    ckpt_path = opts.load_checkpoint
    print(f'Loading checkpoint from {ckpt_path} in rank {rank}')
    # Without map_location, every rank would materialise the tensors on the GPU
    # they were saved from (typically cuda:0).
    snapshot = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    # The online model is still loaded non-strictly, but mismatches are
    # reported instead of being silently ignored.
    result = model.net.load_state_dict(snapshot['model'], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f'rank {rank}: checkpoint mismatch -- missing keys: {result.missing_keys}, '
              f'unexpected keys: {result.unexpected_keys}')
    ema.net.load_state_dict(snapshot['ema'], strict=True)
    opt.load_state_dict(snapshot['optimizer'])
    scheduler.load_state_dict(snapshot['scheduler'])

    return scheduler.last_epoch

def save_ckpt(model, ema, opt, scheduler, path):
    """Write a training snapshot to ``path`` atomically.

    The snapshot dict has the keys ``load_checkpoint`` reads:

    - ``model``: state of the online network's ``.net``, taken from
      ``model.module`` when ``model`` is DDP-wrapped.
    - ``ema``: state of ``ema.net``.
    - ``optimizer``: optimizer state.
    - ``scheduler``: LR scheduler state.
    """
    # Accept both a DDP-wrapped model (exposes `.module`) and a bare one.
    net = getattr(model, 'module', model).net
    snapshot = {
        'model': net.state_dict(),
        'ema': ema.net.state_dict(),
        'optimizer': opt.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    # Write to a sibling temp file and rename over the target, so an interrupted
    # save never leaves a truncated snapshot at `path`.
    tmp_path = f'{path}.tmp'
    torch.save(snapshot, tmp_path)
    os.replace(tmp_path, path)

        
if __name__ == '__main__':
    # `training` is a click command; its `main()` entry point parses the
    # command-line options and runs the training job.
    training.main()